{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-g5l3RLOGXcc"
      },
      "source": [
        "# Tutorial: Embedding StableHLO in SavedModel\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][savedmodel-tutorial-colab]\n",
        "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][savedmodel-tutorial-kaggle]\n",
        "\n",
        "_The [`stablehlo.savedmodel`][savedmodel-module] module._\n",
        "\n",
        "This tutorial will detail how to embed arbitrary StableHLO in a SavedModel. Note that most frameworks have specific APIs for emitting SavedModels, see other StableHLO tutorials for instructions on using these.\n",
        "\n",
        "## Tutorial Setup\n",
        "\n",
        "### Install required dependencies\n",
        "\n",
        "We'll be using the `stablehlo` nightly wheel to get StableHLO's Python APIs, and `tensorflow` for the [SavedModel][savedmodel-tf] dependency.\n",
        "\n",
        "[savedmodel-tf]: https://www.tensorflow.org/guide/saved_model\n",
        "[savedmodel-module]: https://github.com/openxla/stablehlo/tree/main/stablehlo/integrations/python/stablehlo/savedmodel\n",
        "[savedmodel-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb\n",
        "[savedmodel-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/savedmodel-embed.ipynb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B18oDm74-YoZ"
      },
      "outputs": [],
      "source": [
        "!pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels\n",
        "!pip install tensorflow-cpu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eUjF3nxXGyuP"
      },
      "source": [
        "## Embed StableHLO model in SavedModel\n",
        "\n",
        "In this section we'll take a very basic StableHLO module, and demonstrate some of the APIs to embed it in a SavedModel. In practice this StableHLO module can come from a debug dump, an export from a framework, or even converted from HLO.\n",
        "\n",
        "### Define a StableHLO `add` module\n",
        "\n",
        "For this tutorial we'll use a simple `add` model with two input arguments `arg0` and `bias`. When packaging in SavedModel, `bias` will be a constant that is stored in the SavedModel, while `arg0` is provided when calling the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "id": "uBygn_kU-kek"
      },
      "outputs": [],
      "source": [
        "MODULE_STRING = \"\"\"\n",
        "func.func @main(%arg0: tensor<1xf32>, %bias: tensor<1xf32>) -> tensor<1xf32> {\n",
        "  %0 = stablehlo.add %arg0, %bias: tensor<1xf32>\n",
        "  return %0 : tensor<1xf32>\n",
        "}\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0KEsCtO7IZY6"
      },
      "source": [
        "### Parse to a StableHLO MLIR Module\n",
        "\n",
        "Once we have a StableHLO file / dump of interest, we can parse it back to an MLIR module using `ir.Module.parse`.\n",
        "\n",
        "Note that all dialects in the module must be registered, otherwise `parse` will fail."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gTkLk_0E_HPz",
        "outputId": "6b66c9c0-3ca8-4f07-bf68-cf72f7fa235c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "module {\n",
            "  func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {\n",
            "    %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32>\n",
            "    return %0 : tensor<1xf32>\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ],
      "source": [
        "import mlir.ir as ir\n",
        "import mlir.dialects.stablehlo as stablehlo\n",
        "\n",
        "with ir.Context() as ctx:\n",
        "  stablehlo.register_dialect(ctx)\n",
        "  module = ir.Module.parse(MODULE_STRING)\n",
        "\n",
        "print(module)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P-gvoHbtItkg"
      },
      "source": [
        "### Embed in SavedModel using `stablehlo_to_tf_saved_model`\n",
        "\n",
        "StableHLO's Python wheel includes a `savedmodel` module to help with packaging StableHLO in SavedModels.\n",
        "\n",
        "Packing in SavedModel requires a few details:\n",
        "\n",
        "**`input_locations`** specify where inputs to a model live, in the saved model (`InputLocation.parameter`) or passed in as input arguments during invocation (`InputLocation.input_arg`).\n",
        "\n",
        "**`state_dict`** can be used to specify values for the `parameter` arguments that live in the SavedModel. These are linked by `name`.\n",
        "\n",
        "In this example, we'll specify that the second input argument is a value with name `module.bias` which is stored in the SavedModel with the value `2`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "id": "sMZUE-vj_Wg5"
      },
      "outputs": [],
      "source": [
        "from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import InputLocation\n",
        "import numpy as np\n",
        "\n",
        "input_locations = [\n",
        "    InputLocation.input_arg(position=0),          # Parameter, non-constant\n",
        "    InputLocation.parameter(name='module.bias'),  # Constant data in SavedModel\n",
        "]\n",
        "state_dict = {\n",
        "    'module.bias': np.array([2], dtype='float32'),\n",
        "}"
      ]
    },
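    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the link between `input_locations` and `state_dict` concrete, here is a minimal, library-free sketch of how the arguments of `@main` could be assembled at call time. It is a conceptual illustration only, not how `stablehlo_to_tf_saved_model` is implemented: the helper `bind_arguments` and the tuple encoding of locations are hypothetical stand-ins for `InputLocation`, and NumPy's `+` stands in for `stablehlo.add`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-ins for InputLocation.input_arg / InputLocation.parameter.\n",
        "locations = [\n",
        "    ('input_arg', 0),              # taken from the caller's arguments by position\n",
        "    ('parameter', 'module.bias'),  # looked up by name in the stored state\n",
        "]\n",
        "state = {'module.bias': np.array([2], dtype='float32')}\n",
        "\n",
        "def bind_arguments(locations, state, call_args):\n",
        "    # Build the full argument list for @main, one entry per location.\n",
        "    bound = []\n",
        "    for kind, key in locations:\n",
        "        if kind == 'input_arg':\n",
        "            bound.append(call_args[key])\n",
        "        elif kind == 'parameter':\n",
        "            bound.append(state[key])\n",
        "        else:\n",
        "            raise ValueError('unknown location kind: ' + kind)\n",
        "    return bound\n",
        "\n",
        "arg0, bias = bind_arguments(locations, state, [np.array([3], dtype='float32')])\n",
        "result = arg0 + bias  # same arithmetic as the stablehlo.add in @main\n",
        "print(result)\n",
        "```\n",
        "\n",
        "In this sketch the caller supplies only `arg0`; `bias` is resolved by name from the stored state, which mirrors the split between call-time inputs and saved parameters described above."
      ]
    },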
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eKArz9iCK4Pa"
      },
      "source": [
        "Now we can use `stablehlo_to_tf_saved_model` to create the SavedModel in a path specified using the `saved_model_dir` argument."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Q3shJyzHKlbo",
        "outputId": "48383f47-b083-4800-d4a4-0f7fd980d6c4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n"
          ]
        }
      ],
      "source": [
        "from mlir.stablehlo.savedmodel.stablehlo_to_tf_saved_model import stablehlo_to_tf_saved_model\n",
        "\n",
        "stablehlo_to_tf_saved_model(\n",
        "    module,\n",
        "    saved_model_dir='/tmp/add_model',\n",
        "    input_locations=input_locations,\n",
        "    state_dict=state_dict,\n",
        ")\n",
        "\n",
        "!ls /tmp/add_model/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MboZEPwmLIYl"
      },
      "source": [
        "### Reload and call the SavedModel\n",
        "\n",
        "Now we can load that SavedModel and compile using a sample input.\n",
        "\n",
        "Here we'll just use a TF constant with the value `3`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "i4KLLZ7AGB6X",
        "outputId": "10fda712-819a-4002-d4e7-3702087294fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]\n"
          ]
        }
      ],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "restored_model = tf.saved_model.load('/tmp/add_model')\n",
        "print(restored_model.f(tf.constant([3], tf.float32)))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
